"""
BSD 3-Clause License

Copyright (c) Soumith Chintala 2016,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


Copyright 2020 Huawei Technologies Co., Ltd

Licensed under the BSD 3-Clause License (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://spdx.org/licenses/BSD-3-Clause.html

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from collections import OrderedDict  # pylint: disable=g-importing-member
from functools import partial

import torch
import torch.nn as nn

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .registry import register_model
from .layers import (
    GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,
    ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d)


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        **kwargs
    }
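
# For example (illustrative): _cfg(input_size=(3, 384, 384), crop_pct=1.0) returns
# the default dict above with those two keys overridden; the entries in
# default_cfgs below are all built this way.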


default_cfgs = {
    # pretrained on imagenet21k, finetuned on imagenet1k
    'resnetv2_50x1_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz',
        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
    'resnetv2_50x3_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz',
        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
    'resnetv2_101x1_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz',
        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
    'resnetv2_101x3_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz',
        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
    'resnetv2_152x2_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz',
        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
    'resnetv2_152x4_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz',
        input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0),  # only one at 480x480?

    # trained on imagenet-21k
    'resnetv2_50x1_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
        num_classes=21843),
    'resnetv2_50x3_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
        num_classes=21843),
    'resnetv2_101x1_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
        num_classes=21843),
    'resnetv2_101x3_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
        num_classes=21843),
    'resnetv2_152x2_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
        num_classes=21843),
    'resnetv2_152x4_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
        num_classes=21843),

    'resnetv2_50x1_bit_distilled': _cfg(
        url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz',
        interpolation='bicubic'),
    'resnetv2_152x2_bit_teacher': _cfg(
        url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz',
        interpolation='bicubic'),
    'resnetv2_152x2_bit_teacher_384': _cfg(
        url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'),

    'resnetv2_50': _cfg(
        interpolation='bicubic'),
    'resnetv2_50d': _cfg(
        interpolation='bicubic', first_conv='stem.conv1'),
    'resnetv2_50t': _cfg(
        interpolation='bicubic', first_conv='stem.conv1'),
    'resnetv2_101': _cfg(
        interpolation='bicubic'),
    'resnetv2_101d': _cfg(
        interpolation='bicubic', first_conv='stem.conv1'),
    'resnetv2_152': _cfg(
        interpolation='bicubic'),
    'resnetv2_152d': _cfg(
        interpolation='bicubic', first_conv='stem.conv1'),
}


def make_div(v, divisor=8):
    """Round `v` to the nearest multiple of `divisor`, never dropping more than 10% below `v`."""
    min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
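
# A few illustrative values with the default divisor=8:
#   make_div(12) -> 16, make_div(16) -> 16, make_div(60) -> 64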


class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    Follows the implementation of "Identity Mappings in Deep Residual Networks":
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on the 3x3 conv when available.
    """

    def __init__(
            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
        super().__init__()
        first_dilation = first_dilation or dilation
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
                conv_layer=conv_layer, norm_layer=norm_layer)
        else:
            self.downsample = None

        self.norm1 = norm_layer(in_chs)
        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.norm2 = norm_layer(mid_chs)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        self.norm3 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def zero_init_last(self):
        nn.init.zeros_(self.conv3.weight)

    def forward(self, x):
        x_preact = self.norm1(x)

        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x_preact)

        # residual branch
        x = self.conv1(x_preact)
        x = self.conv2(self.norm2(x))
        x = self.conv3(self.norm3(x))
        x = self.drop_path(x)
        return x + shortcut
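
# Minimal usage sketch (illustrative, not executed on import; uses the default
# StdConv2d + GroupNormAct layers and the DownsampleConv projection defined below):
#   block = PreActBottleneck(in_chs=256, out_chs=512, stride=2, proj_layer=DownsampleConv)
#   y = block(torch.randn(1, 256, 56, 56))  # -> torch.Size([1, 512, 28, 28])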


class Bottleneck(nn.Module):
    """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.
    """
    def __init__(
            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
        super().__init__()
        first_dilation = first_dilation or dilation
        act_layer = act_layer or nn.ReLU
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
                conv_layer=conv_layer, norm_layer=norm_layer)
        else:
            self.downsample = None

        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.norm1 = norm_layer(mid_chs)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        self.norm2 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.norm3 = norm_layer(out_chs, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.act3 = act_layer(inplace=True)

    def zero_init_last(self):
        nn.init.zeros_(self.norm3.weight)

    def forward(self, x):
        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        # residual
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.drop_path(x)
        x = self.act3(x + shortcut)
        return x


class DownsampleConv(nn.Module):
    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
            conv_layer=None, norm_layer=None):
        super(DownsampleConv, self).__init__()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
        self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        return self.norm(self.conv(x))


class DownsampleAvg(nn.Module):
    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
            preact=True, conv_layer=None, norm_layer=None):
        """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
        super(DownsampleAvg, self).__init__()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
        else:
            self.pool = nn.Identity()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
        self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        return self.norm(self.conv(self.pool(x)))


class ResNetStage(nn.Module):
    """ResNet Stage."""
    def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
                 avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
                 act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
        super(ResNetStage, self).__init__()
        first_dilation = 1 if dilation in (1, 2) else 2
        layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
        proj_layer = DownsampleAvg if avg_down else DownsampleConv
        prev_chs = in_chs
        self.blocks = nn.Sequential()
        for block_idx in range(depth):
            drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
            stride = stride if block_idx == 0 else 1
            self.blocks.add_module(str(block_idx), block_fn(
                prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
                first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
                **layer_kwargs, **block_kwargs))
            prev_chs = out_chs
            first_dilation = dilation
            proj_layer = None

    def forward(self, x):
        x = self.blocks(x)
        return x


def is_stem_deep(stem_type):
    return any(s in stem_type for s in ('deep', 'tiered'))


def create_resnetv2_stem(
        in_chs, out_chs=64, stem_type='', preact=True,
        conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
    stem = OrderedDict()
    assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')

    # NOTE conv padding mode can be changed by overriding the conv_layer def
    if is_stem_deep(stem_type):
        # A 3-layer-deep 3x3 conv stack as in ResNet V1D models
        if 'tiered' in stem_type:
            stem_chs = (3 * out_chs // 8, out_chs // 2)  # 'T' resnets in resnet.py
        else:
            stem_chs = (out_chs // 2, out_chs // 2)  # 'D' ResNets
        stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2)
        stem['norm1'] = norm_layer(stem_chs[0])
        stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1)
        stem['norm2'] = norm_layer(stem_chs[1])
        stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1)
        if not preact:
            stem['norm3'] = norm_layer(out_chs)
    else:
        # The usual 7x7 stem conv
        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
        if not preact:
            stem['norm'] = norm_layer(out_chs)

    if 'fixed' in stem_type:
        # 'fixed' SAME padding approximation that is used in BiT models
        stem['pad'] = nn.ConstantPad2d(1, 0.)
        stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
    elif 'same' in stem_type:
        # full, input size based 'SAME' padding, used in ViT Hybrid model
        stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same')
    else:
        # the usual PyTorch symmetric padding
        stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    return nn.Sequential(stem)
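
# Usage sketch (illustrative): the default stem is a stride-2 7x7 conv followed by
# a stride-2 3x3 max pool, for an overall spatial reduction of 4:
#   stem = create_resnetv2_stem(in_chs=3, out_chs=64)
#   stem(torch.randn(1, 3, 224, 224)).shape  # -> torch.Size([1, 64, 56, 56])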


class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode.
    """

    def __init__(
            self, layers, channels=(256, 512, 1024, 2048),
            num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
            width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
            act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
            drop_rate=0., drop_path_rate=0., zero_init_last=True):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        wf = width_factor

        self.feature_info = []
        stem_chs = make_div(stem_chs * wf)
        self.stem = create_resnetv2_stem(
            in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
        stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))

        prev_chs = stem_chs
        curr_stride = 4
        dilation = 1
        block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
        block_fn = PreActBottleneck if preact else Bottleneck
        self.stages = nn.Sequential()
        for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
            out_chs = make_div(c * wf)
            stride = 1 if stage_idx == 0 else 2
            if curr_stride >= output_stride:
                dilation *= stride
                stride = 1
            stage = ResNetStage(
                prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
                act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
            prev_chs = out_chs
            curr_stride *= stride
            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
            self.stages.add_module(str(stage_idx), stage)

        self.num_features = prev_chs
        self.norm = norm_layer(self.num_features) if preact else nn.Identity()
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

        self.init_weights(zero_init_last=zero_init_last)

    def init_weights(self, zero_init_last=True):
        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix='resnet/'):
        _load_weights(self, checkpoint_path, prefix)

    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
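
# Minimal usage sketch (illustrative; a pre-act ResNet50-V2 with the default layers):
#   model = ResNetV2(layers=[3, 4, 6, 3], num_classes=10)
#   logits = model(torch.randn(2, 3, 224, 224))  # -> torch.Size([2, 10])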


def _init_weights(module: nn.Module, name: str = '', zero_init_last=True):
    if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
        nn.init.normal_(module.weight, mean=0.0, std=0.01)
        nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
    elif zero_init_last and hasattr(module, 'zero_init_last'):
        module.zero_init_last()


@torch.no_grad()
def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/'):
    import numpy as np

    def t2p(conv_weights):
        """Possibly convert HWIO to OIHW."""
        if conv_weights.ndim == 4:
            conv_weights = conv_weights.transpose([3, 2, 0, 1])
        return torch.from_numpy(conv_weights)

    weights = np.load(checkpoint_path)
    stem_conv_w = adapt_input_conv(
        model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
    model.stem.conv.weight.copy_(stem_conv_w)
    model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
    model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
    if isinstance(getattr(model.head, 'fc', None), nn.Conv2d) and \
            model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
        model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
        model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
    for i, (sname, stage) in enumerate(model.stages.named_children()):
        for j, (bname, block) in enumerate(stage.blocks.named_children()):
            cname = 'standardized_conv2d'
            block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
            block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel']))
            block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel']))
            block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel']))
            block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma']))
            block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma']))
            block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma']))
            block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta']))
            block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta']))
            block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta']))
            if block.downsample is not None:
                w = weights[f'{block_prefix}a/proj/{cname}/kernel']
                block.downsample.conv.weight.copy_(t2p(w))
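
# Usage sketch for the npz loader above (the checkpoint path is illustrative; BiT
# checkpoints store their weights under a 'resnet/' prefix, matching the default):
#   model = resnetv2_50x1_bitm(pretrained=False)
#   model.load_pretrained('./BiT-M-R50x1-ILSVRC2012.npz', prefix='resnet/')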


def _create_resnetv2(variant, pretrained=False, **kwargs):
    feature_cfg = dict(flatten_sequential=True)
    return build_model_with_cfg(
        ResNetV2, variant, pretrained,
        default_cfg=default_cfgs[variant],
        feature_cfg=feature_cfg,
        pretrained_custom_load=True,
        **kwargs)
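
# The @register_model entrypoints below are registered with timm's model factory,
# so equivalently (illustrative): timm.create_model('resnetv2_50x1_bitm', pretrained=False)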


def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
    return _create_resnetv2(
        variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs)


@register_model
def resnetv2_50x1_bitm(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)


@register_model
def resnetv2_50x3_bitm(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)


@register_model
def resnetv2_101x1_bitm(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)


@register_model
def resnetv2_101x3_bitm(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)


@register_model
def resnetv2_152x2_bitm(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)


@register_model
def resnetv2_152x4_bitm(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)


@register_model
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 4, 6, 3], width_factor=1, **kwargs)


@register_model
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 4, 6, 3], width_factor=3, **kwargs)


@register_model
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 4, 23, 3], width_factor=1, **kwargs)


@register_model
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 4, 23, 3], width_factor=3, **kwargs)


@register_model
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 8, 36, 3], width_factor=2, **kwargs)


@register_model
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2_bit(
        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 8, 36, 3], width_factor=4, **kwargs)


@register_model
def resnetv2_50x1_bit_distilled(pretrained=False, **kwargs):
    """ ResNetV2-50x1-BiT Distilled
    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
    """
    return _create_resnetv2_bit(
        'resnetv2_50x1_bit_distilled', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)


@register_model
def resnetv2_152x2_bit_teacher(pretrained=False, **kwargs):
    """ ResNetV2-152x2-BiT Teacher
    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
    """
    return _create_resnetv2_bit(
        'resnetv2_152x2_bit_teacher', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)


@register_model
def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
    """ ResNetV2-152xx-BiT Teacher @ 384x384
    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
    """
    return _create_resnetv2_bit(
        'resnetv2_152x2_bit_teacher_384', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)


@register_model
def resnetv2_50(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50', pretrained=pretrained,
        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)


@register_model
def resnetv2_50d(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50d', pretrained=pretrained,
        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='deep', avg_down=True, **kwargs)


@register_model
def resnetv2_50t(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50t', pretrained=pretrained,
        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='tiered', avg_down=True, **kwargs)


@register_model
def resnetv2_101(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_101', pretrained=pretrained,
        layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)


@register_model
def resnetv2_101d(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_101d', pretrained=pretrained,
        layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='deep', avg_down=True, **kwargs)


@register_model
def resnetv2_152(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_152', pretrained=pretrained,
        layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)


@register_model
def resnetv2_152d(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_152d', pretrained=pretrained,
        layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='deep', avg_down=True, **kwargs)


# @register_model
# def resnetv2_50ebd(pretrained=False, **kwargs):
#     # FIXME for testing w/ TPU + PyTorch XLA
#     return _create_resnetv2(
#         'resnetv2_50d', pretrained=pretrained,
#         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
#         stem_type='deep', avg_down=True, **kwargs)
#
#
# @register_model
# def resnetv2_50esd(pretrained=False, **kwargs):
#     # FIXME for testing w/ TPU + PyTorch XLA
#     return _create_resnetv2(
#         'resnetv2_50d', pretrained=pretrained,
#         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
#         stem_type='deep', avg_down=True, **kwargs)
